import random
from policies.comm_llama_policy import CommLlamaPolicy

class DebriefManager:
    """Coordinate multi-round debriefing between policies.

    Only policies whose exact type is listed in ``allowed_policies`` take
    part. During ``learn`` each participant's ``debrief`` method is called
    in ``debrief_order``; the two values it returns are accumulated in
    ``collective_knowledges`` and ``ego_knowledges`` and handed to every
    subsequent ``debrief`` call.

    Args:
        policies: Mapping of agent id to policy object.
        max_debrief_rounds: Number of accumulate-and-share rounds run before
            the final summary round.
        allowed_policies: Optional iterable of policy classes permitted to
            debrief. Defaults to ``[CommLlamaPolicy]``.
    """

    def __init__(self, policies, max_debrief_rounds=2, allowed_policies=None):
        self.policies = policies
        self.max_debrief_rounds = max_debrief_rounds
        self.debrief_order = list(self.policies)
        # set who can participate in debriefing
        if allowed_policies is None:
            allowed_policies = [CommLlamaPolicy]
        self.allowed_policies = list(allowed_policies)
        # Each entry is a single-item dict {agent_id: value} appended by learn().
        self.ego_knowledges = []
        self.collective_knowledges = []
        # Ids of participating agents, in the initial (unshuffled) order.
        self.agent_in_debrief = [
            agent_id
            for agent_id in self.debrief_order
            if self._can_debrief(self.policies[agent_id])
        ]

    def _can_debrief(self, policy):
        """Return True if ``policy`` is of a type allowed to debrief."""
        # Exact type match: instances of subclasses of an allowed class are
        # deliberately not treated as participants.
        return type(policy) in self.allowed_policies

    def _participants(self):
        """Yield ``(agent_id, policy)`` for participants, in debrief order."""
        for agent_id in self.debrief_order:
            policy = self.policies[agent_id]
            if self._can_debrief(policy):
                yield agent_id, policy

    def learn(self):
        """Run the debrief rounds, a final summary round, then internalize.

        After the debrief calls, every participant internalizes, has its
        replay buffer cleared, its iteration counter incremented and is
        saved under the new iteration number. The manager state is reset
        at the end.
        """
        participants = list(self._participants())
        # Clamp the divisor so that max_debrief_rounds <= 0 (which skips the
        # round loop entirely) does not raise ZeroDivisionError in the final
        # summary round; in that case the full batch size is used.
        batch_divisor = max(1, self.max_debrief_rounds)

        # debriefing by modifying the collective_knowledges
        for _ in range(self.max_debrief_rounds):
            for agent_id, policy in participants:
                collective, ego = policy.debrief(
                    self.collective_knowledges,
                    self.ego_knowledges,
                    self.agent_in_debrief,
                    batch_size=policy.batch_size // batch_divisor,
                    last_round=False,
                )
                self.collective_knowledges.append({agent_id: collective})
                self.ego_knowledges.append({agent_id: ego})

        # Last round summary
        # The return value is intentionally not accumulated here.
        # NOTE(review): presumably debrief(last_round=True) keeps the summary
        # on the policy itself for internalize() -- confirm.
        for _, policy in participants:
            policy.debrief(
                self.collective_knowledges,
                self.ego_knowledges,
                self.agent_in_debrief,
                batch_size=policy.batch_size // batch_divisor,
                last_round=True,
            )

        # internalize the knowledge into the policy
        for _, policy in participants:
            policy.internalize()
            policy.replay_buffer.clear()
            policy.iteration += 1
            policy.save(policy.iteration)

        # reset for next learning
        self.reset()

    def reset(self):
        """Clear accumulated knowledge and draw a new random debrief order."""
        self.ego_knowledges = []
        self.collective_knowledges = []
        self.debrief_order = list(self.policies)
        random.shuffle(self.debrief_order)
